from datetime import datetime
from typing import  Dict
from vnpy.trader.constant import Exchange, Interval
from vnpy2.trader.database import database_manager
from vnpy.app.cta_strategy import ArrayManager
from vnpy.chart import ChartWidget, VolumeItem, CandleItem
import pyqtgraph as pg
from vnpy.trader.ui import create_qapp, QtCore, QtGui
from vnpy.trader.object import BarData
from vnpy.chart.manager import BarManager


class ZB(CandleItem):
    """Custom indicator item: draws each bar's ``down_line`` value as a line.

    In this script ``down_line`` is attached to every bar in the
    ``__main__`` section (the lower band returned by ``am.boll``). The item
    subclasses CandleItem but paints only one connecting line segment per bar.
    """

    def __init__(self, manager: BarManager):
        """Create the item with its pen and an empty per-bar value cache."""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)
        # Bar index -> down_line value; filled lazily by get_sma_value().
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """Return the ``down_line`` value for bar index *ix*.

        The cache is filled from the bar manager on first use. Negative
        indexes return 0. Indexes past the cached range fall back to the
        most recent cached value (0 when no bars were available).
        """
        if ix < 0:
            return 0

        if not self.sma_data:
            bars = self._manager.get_all_bars()
            for n, bar in enumerate(bars):
                self.sma_data[n] = bar.down_line

        if ix in self.sma_data:
            return self.sma_data[ix]

        # Past the cached range (e.g. bars added after the cache was built):
        # reuse the latest cached value; nothing cached means nothing to show.
        if not self.sma_data:
            return 0
        return self.sma_data[max(self.sma_data)]

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the segment joining bar ix-1's value to bar ix's value."""
        sma_value = self.get_sma_value(ix)
        # The first bar has no predecessor. Start its segment at its own value
        # rather than at get_sma_value(-1) == 0, which would draw a drop to 0.
        last_sma_value = self.get_sma_value(ix - 1) if ix > 0 else sma_value

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        try:
            painter.setPen(self.blue_pen)

            start_point = QtCore.QPointF(ix - 1, last_sma_value)
            end_point = QtCore.QPointF(ix, sma_value)
            painter.drawLine(start_point, end_point)
        finally:
            # Always finish the painter so the picture is left in a valid state.
            painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return the cursor info label for bar *ix*, or a placeholder."""
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = f"ZB {sma_value:.2f}"
        else:
            text = "ZB -"

        return text

class ZB2(CandleItem):
    """Custom indicator item: draws each bar's ``up_line`` value as a line.

    In this script ``up_line`` is attached to every bar in the ``__main__``
    section (the upper band returned by ``am.boll``). The item subclasses
    CandleItem but paints only one connecting line segment per bar.
    """

    def __init__(self, manager: BarManager):
        """Create the item with its pen and an empty per-bar value cache."""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)
        # Bar index -> up_line value; filled lazily by get_sma_value().
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """Return the ``up_line`` value for bar index *ix*.

        The cache is filled from the bar manager on first use. Negative
        indexes return 0. Indexes past the cached range fall back to the
        most recent cached value (0 when no bars were available).
        """
        if ix < 0:
            return 0

        if not self.sma_data:
            bars = self._manager.get_all_bars()
            for n, bar in enumerate(bars):
                self.sma_data[n] = bar.up_line

        if ix in self.sma_data:
            return self.sma_data[ix]

        # Past the cached range (e.g. bars added after the cache was built):
        # reuse the latest cached value; nothing cached means nothing to show.
        if not self.sma_data:
            return 0
        return self.sma_data[max(self.sma_data)]

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """Draw the segment joining bar ix-1's value to bar ix's value."""
        sma_value = self.get_sma_value(ix)
        # The first bar has no predecessor. Start its segment at its own value
        # rather than at get_sma_value(-1) == 0, which would draw a drop to 0.
        last_sma_value = self.get_sma_value(ix - 1) if ix > 0 else sma_value

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        try:
            painter.setPen(self.blue_pen)

            start_point = QtCore.QPointF(ix - 1, last_sma_value)
            end_point = QtCore.QPointF(ix, sma_value)
            painter.drawLine(start_point, end_point)
        finally:
            # Always finish the painter so the picture is left in a valid state.
            painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """Return the cursor info label for bar *ix*, or a placeholder.

        Labelled "ZB2" so it is distinguishable from the ZB item that is
        added to the same plot.
        """
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = f"ZB2 {sma_value:.2f}"
        else:
            text = "ZB2 -"

        return text


if __name__ == "__main__":
    app = create_qapp()

    symbol = "CL-20210322-USD-FUT"
    exchange = Exchange.NYMEX
    interval = Interval.MINUTE_30
    start = datetime(2021, 1, 1)
    end = datetime(2022, 1, 1)

    bars = database_manager.load_bar_data(
        symbol=symbol,
        exchange=exchange,
        interval=interval,
        start=start,
        end=end
    )

    # Feed the bars through an ArrayManager and attach the indicator values to
    # each bar as up_line / down_line, which ZB2 / ZB read when drawing.
    # Replace am.boll(...) with whichever ArrayManager formula is wanted.
    # NOTE(review): the first bars are computed before the 20-bar window is
    # full, so their band values may be unrepresentative — confirm whether
    # they should be hidden.
    am = ArrayManager(50)
    processed = 0
    for bar in bars:
        am.update_bar(bar)
        up, down = am.boll(20, 2)
        bar.up_line = up
        bar.down_line = down
        processed += 1

    print("K线数量是", len(bars), "指标数据是", processed)
    print("共处理了", processed, "数据")

    widget = ChartWidget()
    widget.add_plot("candle", hide_x_axis=True)
    widget.add_plot("volume", maximum_height=250)
    widget.add_item(CandleItem, "candle", "candle")
    widget.add_item(VolumeItem, "volume", "volume")

    widget.add_item(ZB, "ZB", "candle")
    widget.add_item(ZB2, "ZB2", "candle")
    widget.add_cursor()

    history = bars
    widget.update_history(history)

    # Bars queued for incremental replay through update_bar(). Nothing is
    # queued in this demo, and the timer below is never started.
    new_data = []

    def update_bar():
        """Push the next queued bar into the chart; stop the timer when drained."""
        if not new_data:
            timer.stop()
            return
        widget.update_bar(new_data.pop(0))

    timer = QtCore.QTimer()
    timer.timeout.connect(update_bar)

    widget.show()
    app.exec_()